{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "74a9e78f-3e8a-4ee2-89fe-b3a3f4784b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cornac\n",
    "from cornac.data import Reader\n",
    "from cornac.datasets import netflix\n",
    "from cornac.eval_methods import RatioSplit\n",
    "from cornac.models import MF"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf6bb9a5-ffb5-4221-8122-9aa286af1d9c",
   "metadata": {},
   "source": [
    "## Train a base recommender model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "76a0c130-7dd7-4004-a613-5b123dcc75d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rating_threshold = 1.0\n",
      "exclude_unknowns = True\n",
      "---\n",
      "Training data:\n",
      "Number of users = 9986\n",
      "Number of items = 4921\n",
      "Number of ratings = 547022\n",
      "Max rating = 1.0\n",
      "Min rating = 1.0\n",
      "Global mean = 1.0\n",
      "---\n",
      "Test data:\n",
      "Number of users = 9986\n",
      "Number of items = 4921\n",
      "Number of ratings = 60747\n",
      "Number of unknown users = 0\n",
      "Number of unknown items = 0\n",
      "---\n",
      "Total users = 9986\n",
      "Total items = 4921\n",
      "\n",
      "[MF] Training started!\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "49afd52e202546a69dd9c4a245f6db80",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization finished!\n",
      "\n",
      "[MF] Evaluation started!\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "190f65f8aa9e4b6d9cd8a0e1805dfc53",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Ranking:   0%|          | 0/8233 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "TEST:\n",
      "...\n",
      "   |    AUC | Recall@20 | Train (s) | Test (s)\n",
      "-- + ------ + --------- + --------- + --------\n",
      "MF | 0.8530 |    0.0669 |    0.9060 |   6.7622\n",
      "\n"
     ]
    }
   ],
   "source": [
    "data = netflix.load_feedback(variant=\"small\", reader=Reader(bin_threshold=1.0))\n",
    "\n",
    "ratio_split = RatioSplit(\n",
    "    data=data,\n",
    "    test_size=0.1,\n",
    "    rating_threshold=1.0,\n",
    "    exclude_unknowns=True,\n",
    "    verbose=True,\n",
    "    seed=123,\n",
    ")\n",
    "\n",
    "mf = MF(\n",
    "    k=50,\n",
    "    max_iter=25, \n",
    "    learning_rate=0.01, \n",
    "    lambda_reg=0.02, \n",
    "    use_bias=False,\n",
    "    verbose=True,\n",
    "    seed=123,\n",
    ")\n",
    "\n",
    "auc = cornac.metrics.AUC()\n",
    "rec_20 = cornac.metrics.Recall(k=20)\n",
    "\n",
    "cornac.Experiment(\n",
    "    eval_method=ratio_split,\n",
    "    models=[mf],\n",
    "    metrics=[auc, rec_20],\n",
    "    user_based=True,\n",
    ").run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b6707bc-a63a-4729-8e49-51bafca31723",
   "metadata": {},
   "source": [
    "## Test setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "46c5f959-f706-4406-85c9-a4acf2ac20c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "K = 20\n",
    "N = 10000\n",
    "test_users = np.random.RandomState(123).choice(mf.user_ids, size=N)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68f6350e-ba2c-4112-87f4-efe780ca314f",
   "metadata": {},
   "source": [
    "### Time taken by the base model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6730a93f-729d-4bea-b5aa-8fb9d9cef867",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1min 14s, sys: 27.3 ms, total: 1min 14s\n",
      "Wall time: 1.56 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "mf_recs = []\n",
    "for uid in test_users:\n",
    "    mf_recs.append(mf.recommend(uid, k=K))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9808c778-9513-483f-bc76-87aa63541bbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_recall(retrieved_neighbors, true_neighbors):\n",
    "    total = 0\n",
    "    for retrieved, true in zip(retrieved_neighbors, true_neighbors):\n",
    "        total += len(set(retrieved) & set(true))\n",
    "    return total / (N * K)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a1ebae2-caeb-4379-8391-1a670fb7e837",
   "metadata": {},
   "source": [
    "## Test performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fa280b9d-ec04-41eb-9de2-acfb67fbeb80",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AnnoyANN\t\tIndexing=34ms\t\tRetrieval=589ms\t\tRecall=0.01299\n",
      "FaissANN\t\tIndexing=109ms\t\tRetrieval=905ms\t\tRecall=0.99938\n",
      "HNSWLibANN\t\tIndexing=91ms\t\tRetrieval=215ms\t\tRecall=0.99874\n",
      "ScaNNANN\t\tIndexing=1564ms\t\tRetrieval=479ms\t\tRecall=0.99997\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "from cornac.models import AnnoyANN\n",
    "from cornac.models import FaissANN\n",
    "from cornac.models import HNSWLibANN\n",
    "from cornac.models import ScaNNANN\n",
    "\n",
    "anns = [\n",
    "    AnnoyANN(model=mf, n_trees=20, search_k=500, seed=123, num_threads=-1),\n",
    "    FaissANN(model=mf, nlist=100, nprobe=50, use_gpu=False, seed=123, num_threads=-1),\n",
    "    HNSWLibANN(model=mf, M=16, ef_construction=100, ef=50, seed=123, num_threads=-1),\n",
    "    ScaNNANN(\n",
    "        model=mf,\n",
    "        partition_params={\"num_leaves\": 100, \"num_leaves_to_search\": 50},\n",
    "        score_params={\"dimensions_per_block\": 2, \"anisotropic_quantization_threshold\": 0.2}, \n",
    "        rescore_params={\"reordering_num_neighbors\": 100},\n",
    "        seed=123, num_threads=-1,\n",
    "    ),\n",
    "]\n",
    "\n",
    "for ann in anns:\n",
    "    s = time.time()\n",
    "    ann.build_index()\n",
    "    t1 = time.time() - s\n",
    "    \n",
    "    s = time.time()\n",
    "    ann_recs = []\n",
    "    for uid in test_users:\n",
    "        ann_recs.append(ann.recommend(uid, k=K))\n",
    "    t2 = time.time() - s\n",
    "\n",
    "    recall = compute_recall(ann_recs, mf_recs)\n",
    "    print(f\"{ann.name}\\t\\tIndexing={int(t1*1000)}ms\\t\\tRetrieval={int(t2*1000)}ms\\t\\tRecall={recall:.5f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "daf585e5-f84b-4dc5-99f2-d6a5bd05ede4",
   "metadata": {},
   "source": [
    "## Test save/load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1fec7413-482b-417b-a12f-362a93aefb85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AnnoyANN True\n",
      "FaissANN True\n",
      "HNSWLibANN True\n",
      "ScaNNANN True\n"
     ]
    }
   ],
   "source": [
    "for ann in anns:\n",
    "    saved_path = ann.save(\"save_dir\")\n",
    "    loaded_ann = ann.load(saved_path)\n",
    "    print(\n",
    "        ann.name, \n",
    "        np.array_equal(\n",
    "            ann.recommend_batch(test_users[:10], k=K), \n",
    "            loaded_ann.recommend_batch(test_users[:10], k=K)\n",
    "        )\n",
    "    ) "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cornac",
   "language": "python",
   "name": "cornac"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
